{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 6. Recurrent Neural Networks and Language Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture8.pdf\n",
    "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture9.pdf\n",
    "* http://colah.github.io/posts/2015-08-Understanding-LSTMs/\n",
    "* https://github.com/pytorch/examples/tree/master/word_language_model\n",
    "* https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/language_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import nltk\n",
    "import random\n",
    "import numpy as np\n",
    "from collections import Counter, OrderedDict\n",
    "import nltk\n",
    "from copy import deepcopy\n",
    "flatten = lambda l: [item for sublist in l for item in sublist]\n",
    "random.seed(1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "USE_CUDA = torch.cuda.is_available()\n",
    "gpus = [0]\n",
    "torch.cuda.set_device(gpus[0])\n",
    "\n",
    "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
    "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
    "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def prepare_sequence(seq, to_index):\n",
    "    \"\"\"Convert a token sequence into a LongTensor of vocabulary indices.\n",
    "\n",
    "    Tokens whose lookup in `to_index` yields None are replaced by the index\n",
    "    stored under '<unk>'.\n",
    "    \"\"\"\n",
    "    indices = []\n",
    "    for token in seq:\n",
    "        index = to_index.get(token)\n",
    "        if index is None:\n",
    "            # Out-of-vocabulary token: fall back to the unknown-word index.\n",
    "            index = to_index['<unk>']\n",
    "        indices.append(index)\n",
    "    return LongTensor(indices)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data load and Preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Penn TreeBank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def prepare_ptb_dataset(filename, word2index=None):\n",
    "    \"\"\"Read a whitespace-tokenised corpus file and encode it as one index sequence.\n",
    "\n",
    "    Each line is stripped, split on whitespace and terminated with '</s>'; all\n",
    "    lines are concatenated into a single token stream.\n",
    "\n",
    "    Args:\n",
    "        filename: path of the UTF-8 text file to read.\n",
    "        word2index: vocabulary to reuse. When None, a new vocabulary is built from\n",
    "            this file with '<unk>' fixed at index 0.\n",
    "\n",
    "    Returns:\n",
    "        (indices, word2index): the output of prepare_sequence() for the token\n",
    "        stream, and the vocabulary that was used to encode it.\n",
    "    \"\"\"\n",
    "    # Context manager so the file handle is closed as soon as reading finishes.\n",
    "    with open(filename, 'r', encoding='utf-8') as f:\n",
    "        corpus = flatten([line.strip().split() + ['</s>'] for line in f])\n",
    "\n",
    "    if word2index is None:\n",
    "        word2index = {'<unk>': 0}\n",
    "        # Sorted so the word -> index mapping is the same on every run; iterating\n",
    "        # a raw set of strings can change order between interpreter sessions.\n",
    "        for word in sorted(set(corpus)):\n",
    "            # The corpus may itself contain '<unk>'; keep it at index 0.\n",
    "            if word not in word2index:\n",
    "                word2index[word] = len(word2index)\n",
    "\n",
    "    return prepare_sequence(corpus, word2index), word2index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# borrowed code from https://github.com/pytorch/examples/tree/master/word_language_model\n",
    "\n",
    "def batchify(data, bsz):\n",
    "    \"\"\"Trim `data` along dim 0 to a multiple of `bsz` and view it as `bsz` rows.\n",
    "\n",
    "    Elements past the last complete multiple of `bsz` are dropped. The result is\n",
    "    moved to the GPU when USE_CUDA is set.\n",
    "    \"\"\"\n",
    "    row_len = data.size(0) // bsz\n",
    "    # Keep only the leading bsz * row_len elements so the view below divides evenly.\n",
    "    usable = data.narrow(0, 0, row_len * bsz)\n",
    "    rows = usable.view(bsz, -1).contiguous()\n",
    "    return rows.cuda() if USE_CUDA else rows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def getBatch(data, seq_length):\n",
    "     for i in range(0, data.size(1) - seq_length, seq_length):\n",
    "        inputs = Variable(data[:, i: i + seq_length])\n",
    "        targets = Variable(data[:, (i + 1): (i + 1) + seq_length].contiguous())\n",
    "        yield (inputs, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# The vocabulary is built from the training split only and reused for the other\n",
    "# splits, so words unseen in training are encoded as '<unk>'.\n",
    "train_data, word2index = prepare_ptb_dataset('../dataset/ptb/ptb.train.txt')\n",
    "dev_data, _ = prepare_ptb_dataset('../dataset/ptb/ptb.valid.txt', word2index=word2index)\n",
    "test_data, _ = prepare_ptb_dataset('../dataset/ptb/ptb.test.txt', word2index=word2index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10000"
      ]
     },
     "execution_count": 178,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Vocabulary size: number of entries in word2index.\n",
    "len(word2index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Inverse vocabulary: index -> word.\n",
    "index2word = dict((index, word) for word, index in word2index.items())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modeling "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../images/06.rnnlm-architecture.png\">\n",
    "<center>borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture8.pdf</center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 180,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class LanguageModel(nn.Module): \n",
    "    def __init__(self, vocab_size, embedding_size, hidden_size, n_layers=1, dropout_p=0.5):\n",
    "\n",
    "        super(LanguageModel, self).__init__()\n",
    "        self.n_layers = n_layers\n",
    "        self.hidden_size = hidden_size\n",
    "        self.embed = nn.Embedding(vocab_size, embedding_size)\n",
    "        self.rnn = nn.LSTM(embedding_size, hidden_size, n_layers, batch_first=True)\n",
    "        self.linear = nn.Linear(hidden_size, vocab_size)\n",
    "        self.dropout = nn.Dropout(dropout_p)\n",
    "        \n",
    "    def init_weight(self):\n",
    "        self.embed.weight = nn.init.xavier_uniform(self.embed.weight)\n",
    "        self.linear.weight = nn.init.xavier_uniform(self.linear.weight)\n",
    "        self.linear.bias.data.fill_(0)\n",
    "        \n",
    "    def init_hidden(self,batch_size):\n",
    "        hidden = Variable(torch.zeros(self.n_layers,batch_size,self.hidden_size))\n",
    "        context = Variable(torch.zeros(self.n_layers,batch_size,self.hidden_size))\n",
    "        return (hidden.cuda(), context.cuda()) if USE_CUDA else (hidden, context)\n",
    "    \n",
    "    def detach_hidden(self, hiddens):\n",
    "        return tuple([hidden.detach() for hidden in hiddens])\n",
    "    \n",
    "    def forward(self, inputs, hidden, is_training=False): \n",
    "\n",
    "        embeds = self.embed(inputs)\n",
    "        if is_training:\n",
    "            embeds = self.dropout(embeds)\n",
    "        out,hidden = self.rnn(embeds, hidden)\n",
    "        return self.linear(out.contiguous().view(out.size(0) * out.size(1), -1)), hidden"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training takes a while..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Hyperparameters for the language model and its training loop.\n",
    "# Passed to LanguageModel as embedding_size.\n",
    "EMBED_SIZE = 128\n",
    "# Passed to LanguageModel as hidden_size.\n",
    "HIDDEN_SIZE = 1024\n",
    "# Passed to LanguageModel as n_layers.\n",
    "NUM_LAYER = 1\n",
    "# Initial Adam learning rate; the training loop multiplies it by 0.1 once.\n",
    "LR = 0.01\n",
    "SEQ_LENGTH = 30 # window length passed to getBatch (truncated BPTT)\n",
    "# Rows in the batchified training data; dev/test use BATCH_SIZE//2.\n",
    "BATCH_SIZE = 20\n",
    "# Number of passes over the training data.\n",
    "EPOCH = 40\n",
    "# Set to True by the training loop once the learning rate has been decayed.\n",
    "RESCHEDULED = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Evaluation splits use half the training batch size.\n",
    "eval_batch_size = BATCH_SIZE // 2\n",
    "train_data = batchify(train_data, BATCH_SIZE)\n",
    "dev_data = batchify(dev_data, eval_batch_size)\n",
    "test_data = batchify(test_data, eval_batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 185,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model = LanguageModel(vocab_size=len(word2index),\n",
    "                      embedding_size=EMBED_SIZE,\n",
    "                      hidden_size=HIDDEN_SIZE,\n",
    "                      n_layers=NUM_LAYER,\n",
    "                      dropout_p=0.5)\n",
    "model.init_weight()\n",
    "if USE_CUDA:\n",
    "    model = model.cuda()\n",
    "\n",
    "# Cross-entropy over the vocabulary, optimised with Adam at the initial LR.\n",
    "loss_function = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=LR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 186,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[00/40] mean_loss : 9.45, Perplexity : 12712.23\n",
      "[00/40] mean_loss : 5.88, Perplexity : 358.21\n",
      "[00/40] mean_loss : 5.55, Perplexity : 256.44\n",
      "[01/40] mean_loss : 5.38, Perplexity : 217.46\n",
      "[01/40] mean_loss : 5.21, Perplexity : 182.41\n",
      "[01/40] mean_loss : 5.10, Perplexity : 164.39\n",
      "[02/40] mean_loss : 5.08, Perplexity : 160.87\n",
      "[02/40] mean_loss : 4.99, Perplexity : 147.18\n",
      "[02/40] mean_loss : 4.92, Perplexity : 136.52\n",
      "[03/40] mean_loss : 4.92, Perplexity : 136.64\n",
      "[03/40] mean_loss : 4.86, Perplexity : 129.32\n",
      "[03/40] mean_loss : 4.80, Perplexity : 121.46\n",
      "[04/40] mean_loss : 4.80, Perplexity : 121.91\n",
      "[04/40] mean_loss : 4.77, Perplexity : 117.64\n",
      "[04/40] mean_loss : 4.71, Perplexity : 111.22\n",
      "[05/40] mean_loss : 4.72, Perplexity : 112.01\n",
      "[05/40] mean_loss : 4.70, Perplexity : 109.46\n",
      "[05/40] mean_loss : 4.64, Perplexity : 103.96\n",
      "[06/40] mean_loss : 4.66, Perplexity : 105.25\n",
      "[06/40] mean_loss : 4.64, Perplexity : 103.63\n",
      "[06/40] mean_loss : 4.60, Perplexity : 99.00\n",
      "[07/40] mean_loss : 4.60, Perplexity : 99.89\n",
      "[07/40] mean_loss : 4.59, Perplexity : 98.97\n",
      "[07/40] mean_loss : 4.55, Perplexity : 94.97\n",
      "[08/40] mean_loss : 4.56, Perplexity : 95.54\n",
      "[08/40] mean_loss : 4.56, Perplexity : 95.67\n",
      "[08/40] mean_loss : 4.52, Perplexity : 91.98\n",
      "[09/40] mean_loss : 4.53, Perplexity : 92.61\n",
      "[09/40] mean_loss : 4.53, Perplexity : 92.79\n",
      "[09/40] mean_loss : 4.50, Perplexity : 89.63\n",
      "[10/40] mean_loss : 4.50, Perplexity : 90.13\n",
      "[10/40] mean_loss : 4.50, Perplexity : 90.19\n",
      "[10/40] mean_loss : 4.47, Perplexity : 87.11\n",
      "[11/40] mean_loss : 4.48, Perplexity : 88.11\n",
      "[11/40] mean_loss : 4.48, Perplexity : 88.26\n",
      "[11/40] mean_loss : 4.45, Perplexity : 86.05\n",
      "[12/40] mean_loss : 4.46, Perplexity : 86.81\n",
      "[12/40] mean_loss : 4.47, Perplexity : 87.03\n",
      "[12/40] mean_loss : 4.43, Perplexity : 84.04\n",
      "[13/40] mean_loss : 4.45, Perplexity : 85.27\n",
      "[13/40] mean_loss : 4.45, Perplexity : 85.83\n",
      "[13/40] mean_loss : 4.42, Perplexity : 83.33\n",
      "[14/40] mean_loss : 4.43, Perplexity : 84.15\n",
      "[14/40] mean_loss : 4.43, Perplexity : 84.31\n",
      "[14/40] mean_loss : 4.41, Perplexity : 82.29\n",
      "[15/40] mean_loss : 4.43, Perplexity : 83.82\n",
      "[15/40] mean_loss : 4.43, Perplexity : 83.70\n",
      "[15/40] mean_loss : 4.40, Perplexity : 81.59\n",
      "[16/40] mean_loss : 4.42, Perplexity : 83.06\n",
      "[16/40] mean_loss : 4.42, Perplexity : 83.29\n",
      "[16/40] mean_loss : 4.39, Perplexity : 80.89\n",
      "[17/40] mean_loss : 4.41, Perplexity : 82.44\n",
      "[17/40] mean_loss : 4.41, Perplexity : 82.51\n",
      "[17/40] mean_loss : 4.39, Perplexity : 80.59\n",
      "[18/40] mean_loss : 4.40, Perplexity : 81.59\n",
      "[18/40] mean_loss : 4.41, Perplexity : 82.21\n",
      "[18/40] mean_loss : 4.38, Perplexity : 79.87\n",
      "[19/40] mean_loss : 4.40, Perplexity : 81.43\n",
      "[19/40] mean_loss : 4.40, Perplexity : 81.67\n",
      "[19/40] mean_loss : 4.37, Perplexity : 79.28\n",
      "[20/40] mean_loss : 4.40, Perplexity : 81.18\n",
      "[20/40] mean_loss : 4.40, Perplexity : 81.17\n",
      "[20/40] mean_loss : 4.37, Perplexity : 79.11\n",
      "[21/40] mean_loss : 4.40, Perplexity : 81.44\n",
      "[21/40] mean_loss : 4.34, Perplexity : 76.43\n",
      "[21/40] mean_loss : 4.21, Perplexity : 67.17\n",
      "[22/40] mean_loss : 4.26, Perplexity : 70.84\n",
      "[22/40] mean_loss : 4.26, Perplexity : 70.75\n",
      "[22/40] mean_loss : 4.17, Perplexity : 64.99\n",
      "[23/40] mean_loss : 4.22, Perplexity : 68.36\n",
      "[23/40] mean_loss : 4.22, Perplexity : 67.82\n",
      "[23/40] mean_loss : 4.15, Perplexity : 63.74\n",
      "[24/40] mean_loss : 4.20, Perplexity : 66.66\n",
      "[24/40] mean_loss : 4.20, Perplexity : 66.43\n",
      "[24/40] mean_loss : 4.14, Perplexity : 62.85\n",
      "[25/40] mean_loss : 4.18, Perplexity : 65.53\n",
      "[25/40] mean_loss : 4.17, Perplexity : 64.99\n",
      "[25/40] mean_loss : 4.13, Perplexity : 61.94\n",
      "[26/40] mean_loss : 4.17, Perplexity : 64.61\n",
      "[26/40] mean_loss : 4.16, Perplexity : 64.34\n",
      "[26/40] mean_loss : 4.12, Perplexity : 61.27\n",
      "[27/40] mean_loss : 4.15, Perplexity : 63.73\n",
      "[27/40] mean_loss : 4.15, Perplexity : 63.32\n",
      "[27/40] mean_loss : 4.11, Perplexity : 60.87\n",
      "[28/40] mean_loss : 4.14, Perplexity : 62.96\n",
      "[28/40] mean_loss : 4.14, Perplexity : 63.01\n",
      "[28/40] mean_loss : 4.10, Perplexity : 60.33\n",
      "[29/40] mean_loss : 4.14, Perplexity : 62.54\n",
      "[29/40] mean_loss : 4.13, Perplexity : 62.36\n",
      "[29/40] mean_loss : 4.10, Perplexity : 60.06\n",
      "[30/40] mean_loss : 4.13, Perplexity : 62.05\n",
      "[30/40] mean_loss : 4.13, Perplexity : 61.91\n",
      "[30/40] mean_loss : 4.09, Perplexity : 59.46\n",
      "[31/40] mean_loss : 4.12, Perplexity : 61.45\n",
      "[31/40] mean_loss : 4.11, Perplexity : 61.24\n",
      "[31/40] mean_loss : 4.08, Perplexity : 59.12\n",
      "[32/40] mean_loss : 4.11, Perplexity : 61.03\n",
      "[32/40] mean_loss : 4.11, Perplexity : 60.88\n",
      "[32/40] mean_loss : 4.07, Perplexity : 58.69\n",
      "[33/40] mean_loss : 4.11, Perplexity : 60.71\n",
      "[33/40] mean_loss : 4.10, Perplexity : 60.57\n",
      "[33/40] mean_loss : 4.07, Perplexity : 58.38\n",
      "[34/40] mean_loss : 4.10, Perplexity : 60.33\n",
      "[34/40] mean_loss : 4.10, Perplexity : 60.23\n",
      "[34/40] mean_loss : 4.06, Perplexity : 58.06\n",
      "[35/40] mean_loss : 4.09, Perplexity : 60.00\n",
      "[35/40] mean_loss : 4.09, Perplexity : 59.74\n",
      "[35/40] mean_loss : 4.06, Perplexity : 57.75\n",
      "[36/40] mean_loss : 4.09, Perplexity : 59.58\n",
      "[36/40] mean_loss : 4.09, Perplexity : 59.47\n",
      "[36/40] mean_loss : 4.05, Perplexity : 57.59\n",
      "[37/40] mean_loss : 4.08, Perplexity : 59.30\n",
      "[37/40] mean_loss : 4.08, Perplexity : 59.11\n",
      "[37/40] mean_loss : 4.05, Perplexity : 57.11\n",
      "[38/40] mean_loss : 4.08, Perplexity : 58.98\n",
      "[38/40] mean_loss : 4.07, Perplexity : 58.70\n",
      "[38/40] mean_loss : 4.04, Perplexity : 57.10\n",
      "[39/40] mean_loss : 4.07, Perplexity : 58.79\n",
      "[39/40] mean_loss : 4.07, Perplexity : 58.58\n",
      "[39/40] mean_loss : 4.04, Perplexity : 56.79\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(EPOCH):\n",
    "    losses = []\n",
    "    # Start every pass over the training data from a zero LSTM state.\n",
    "    hidden = model.init_hidden(BATCH_SIZE)\n",
    "    for i, (inputs, targets) in enumerate(getBatch(train_data, SEQ_LENGTH)):\n",
    "        # Truncated BPTT: keep the state values but cut the graph so backward()\n",
    "        # does not reach into earlier windows.\n",
    "        hidden = model.detach_hidden(hidden)\n",
    "        model.zero_grad()\n",
    "        preds, hidden = model(inputs, hidden, True)\n",
    "\n",
    "        loss = loss_function(preds, targets.view(-1))\n",
    "        # .item() exists on PyTorch >= 0.4; older releases expose the scalar as\n",
    "        # loss.data[0], which fails on the 0-dim tensors of newer releases.\n",
    "        losses.append(loss.item() if hasattr(loss, 'item') else loss.data[0])\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm(model.parameters(), 0.5) # gradient clipping\n",
    "        optimizer.step()\n",
    "\n",
    "        if i > 0 and i % 500 == 0:\n",
    "            # Report the mean loss since the previous report, then reset the window.\n",
    "            mean_loss = np.mean(losses)\n",
    "            print(\"[%02d/%d] mean_loss : %0.2f, Perplexity : %0.2f\" % (epoch, EPOCH, mean_loss, np.exp(mean_loss)))\n",
    "            losses = []\n",
    "\n",
    "    # Learning rate annealing: cut LR by 10x once, at the end of epoch EPOCH//2.\n",
    "    # Re-creating Adam also discards its accumulated moment estimates.\n",
    "    # You can use http://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate\n",
    "    if not RESCHEDULED and epoch == EPOCH//2:\n",
    "        LR *= 0.1\n",
    "        optimizer = optim.Adam(model.parameters(), lr=LR)\n",
    "        RESCHEDULED = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 189,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Perpelexity : 155.89\n"
     ]
    }
   ],
   "source": [
    "total_loss = 0.0\n",
    "n_steps = 0\n",
    "hidden = model.init_hidden(BATCH_SIZE//2)\n",
    "for inputs, targets in getBatch(test_data, SEQ_LENGTH):\n",
    "    # Carry the LSTM state across windows without keeping the old graph alive.\n",
    "    hidden = model.detach_hidden(hidden)\n",
    "    preds, hidden = model(inputs, hidden)\n",
    "    batch_loss = loss_function(preds, targets.view(-1))\n",
    "    # .item() exists on PyTorch >= 0.4; older releases expose the scalar as .data[0].\n",
    "    batch_loss = batch_loss.item() if hasattr(batch_loss, 'item') else batch_loss.data[0]\n",
    "    # The loss is a mean over the window, so weight it by the window length.\n",
    "    total_loss += inputs.size(1) * batch_loss\n",
    "    n_steps += inputs.size(1)\n",
    "\n",
    "if n_steps == 0:\n",
    "    raise ValueError('test_data has too few columns for a single SEQ_LENGTH window')\n",
    "\n",
    "# Normalise by the time steps actually scored: getBatch() skips the trailing\n",
    "# columns, so dividing by test_data.size(1) would understate the mean loss.\n",
    "total_loss = total_loss / n_steps\n",
    "print(\"Test Perpelexity : %5.2f\" % (np.exp(total_loss)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Further topics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* <a href=\"https://arxiv.org/pdf/1609.07843.pdf\">Pointer Sentinel Mixture Models</a>\n",
    "* <a href=\"https://arxiv.org/pdf/1708.02182\">Regularizing and Optimizing LSTM Language Models</a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
